{
 "cells": [
  {
   "cell_type": "raw",
   "id": "c6e3d313-4b94-4b76-932c-fd67a999361a",
   "metadata": {},
   "source": [
    "---\n",
    "title: metrics\n",
    "subtitle: Image similarity metrics and geodesic distances for camera poses\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a3d1f8e-5b89-41b3-889d-2b054f63c125",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a22f019-d10f-4c94-aa9e-fc6d32f74f4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a452921-86fb-49c1-a3c6-749fad621a55",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from diffdrr.metrics import (\n",
    "    GradientNormalizedCrossCorrelation2d,\n",
    "    MultiscaleNormalizedCrossCorrelation2d,\n",
    "    NormalizedCrossCorrelation2d,\n",
    ")\n",
    "from torchmetrics import Metric"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "870e3dc8-fc8d-46ef-9204-66847c1cf4df",
   "metadata": {},
   "source": [
    "## Image similarity metrics\n",
    "\n",
    "These metrics quantify the similarity between ground truth X-rays ($\\mathbf I$) and synthetic X-rays rendered from estimated camera poses ($\\hat{\\mathbf I}$). If a metric is differentiable, it can also serve as a loss function for optimizing camera poses with `DiffDRR`."
   ]
  },
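  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d5b1780-fa0a-4d81-8103-40572c938431",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class CustomMetric(Metric):\n",
    "    \"\"\"Wrap a loss module as a `torchmetrics` metric that averages it over all samples seen.\"\"\"\n",
    "\n",
    "    is_differentiable = True\n",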
    "\n",
    "    def __init__(self, LossClass, **kwargs):\n",
    "        super().__init__()\n",
    "        self.lossfn = LossClass(**kwargs)\n",
    "        self.add_state(\"loss\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n",
    "        self.add_state(\"count\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n",
    "\n",
    "    def update(self, preds, target):\n",
    "        self.loss += self.lossfn(preds, target).sum()\n",
    "        self.count += len(preds)\n",
    "\n",
    "    def compute(self):\n",
    "        return self.loss.float() / self.count"
   ]
  },
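  {
   "cell_type": "markdown",
   "id": "a41f0c7e-custom-metric-running-mean",
   "metadata": {},
   "source": [
    "As a brief sketch of what `CustomMetric` accumulates, assume the wrapped loss returns one value per image, written here as $\\ell(\\hat{\\mathbf I}_{k,i}, \\mathbf I_{k,i})$ for image $i$ of batch $k$ (the symbols $\\ell$, $k$, $i$, $K$, and $n_k$ are introduced for illustration only). After $K$ calls to `update` with batch sizes $n_1, \\dots, n_K$, `compute` returns the running mean\n",
    "\\begin{equation}\n",
    "    \\frac{\\sum_{k=1}^{K} \\sum_{i=1}^{n_k} \\ell(\\hat{\\mathbf I}_{k,i}, \\mathbf I_{k,i})}{\\sum_{k=1}^{K} n_k} \\,,\n",
    "\\end{equation}\n",
    "so batches of different sizes are weighted by the number of images they contain."
   ]
  },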
  {
   "cell_type": "markdown",
   "id": "cbfb768c-3745-47b6-97c8-760debc83fa6",
   "metadata": {},
   "source": [
    "`NCC`, multiscale `NCC`, and `GradNCC` are implemented in [`diffdrr.metrics`](https://github.com/eigenvivek/DiffDRR/blob/main/notebooks/api/05_metrics.ipynb).\n",
    "`DiffPose` provides `torchmetrics` wrappers for these functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc49ea90-0841-4ac5-9b17-2faab1191d10",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class NormalizedCrossCorrelation(CustomMetric):\n",
    "    \"\"\"`torchmetrics` wrapper for NCC.\"\"\"\n",
    "\n",
    "    higher_is_better = True\n",
    "\n",
    "    def __init__(self, patch_size=None):\n",
    "        super().__init__(NormalizedCrossCorrelation2d, patch_size=patch_size)\n",
    "\n",
    "\n",
    "class MultiscaleNormalizedCrossCorrelation(CustomMetric):\n",
    "    \"\"\"`torchmetrics` wrapper for Multiscale NCC.\"\"\"\n",
    "\n",
    "    higher_is_better = True\n",
    "\n",
    "    def __init__(self, patch_sizes, patch_weights):\n",
    "        super().__init__(\n",
    "            MultiscaleNormalizedCrossCorrelation2d,\n",
    "            patch_sizes=patch_sizes,\n",
    "            patch_weights=patch_weights,\n",
    "        )\n",
    "\n",
    "\n",
    "class GradientNormalizedCrossCorrelation(CustomMetric):\n",
    "    \"\"\"`torchmetrics` wrapper for GradNCC.\"\"\"\n",
    "\n",
    "    higher_is_better = True\n",
    "\n",
    "    def __init__(self, patch_size=None):\n",
    "        super().__init__(GradientNormalizedCrossCorrelation2d, patch_size=patch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f394f290-bac4-4831-bec6-5ac88dcdd914",
   "metadata": {},
   "source": [
    "## Geodesic distances for SO(3) and SE(3)\n",
    "\n",
    "One can define geodesic pseudo-distances on SO(3) and SE(3).[^1] This lets us measure registration error (in radians and millimeters, respectively) directly on poses, rather than needing to compute the projection of fiducials.\n",
    "\n",
    "- **For SO(3)**, the geodesic distance between two rotation matrices $\\mathbf R_A$ and $\\mathbf R_B$ is\n",
    "\\begin{equation}\n",
    "    d_\\theta(\\mathbf R_A, \\mathbf R_B; r) = r \\left| \\arccos \\left( \\frac{\\mathrm{trace}(\\mathbf R_A^* \\mathbf R_B) - 1}{2} \\right ) \\right| \\,,\n",
    "\\end{equation}\n",
    "where $r$, the source-to-detector radius, is used to convert radians to millimeters.\n",
    "\n",
    "- **For SE(3)**, we decompose the transformation matrix into a rotation and a translation, i.e., $\\mathbf T = (\\mathbf R, \\mathbf t)$.\n",
    "Then, we compute the geodesic on translations (just Euclidean distance),\n",
    "\\begin{equation}\n",
    "    d_t(\\mathbf t_A, \\mathbf t_B) = \\| \\mathbf t_A - \\mathbf t_B \\|_2 \\,.\n",
    "\\end{equation}\n",
    "Finally, we compute the *double geodesic* on the rotations and translations:\n",
    "\\begin{equation}\n",
    "    d(\\mathbf T_A, \\mathbf T_B; r) = \\sqrt{d_\\theta(\\mathbf R_A, \\mathbf R_B; r)^2 + d_t(\\mathbf t_A, \\mathbf t_B)^2} \\,.\n",
    "\\end{equation}\n",
    "\n",
    "[^1]: [https://vnav.mit.edu/material/04-05-LieGroups-notes.pdf](https://vnav.mit.edu/material/04-05-LieGroups-notes.pdf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6400a4b1-417e-48b5-aa7f-207c8cdf893d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import torch\n",
    "from beartype import beartype\n",
    "from diffdrr.utils import (\n",
    "    convert,\n",
    "    so3_log_map,\n",
    "    so3_relative_angle,\n",
    "    so3_rotation_angle,\n",
    "    standardize_quaternion,\n",
    ")\n",
    "from jaxtyping import Float, jaxtyped\n",
    "\n",
    "from diffpose.calibration import RigidTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ff308dc-4807-46dd-bd10-ef9dca35c4c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class GeodesicSO3(torch.nn.Module):\n",
    "    \"\"\"Calculate the angular distance between two rotations in SO(3).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    @jaxtyped(typechecker=beartype)\n",
    "    def forward(\n",
    "        self,\n",
    "        pose_1: RigidTransform,\n",
    "        pose_2: RigidTransform,\n",
    "    ) -> Float[torch.Tensor, \"b\"]:\n",
    "        r1 = pose_1.get_rotation()\n",
    "        r2 = pose_2.get_rotation()\n",
    "        rdiff = r1 @ r2.transpose(-1, -2)\n",
    "        return so3_log_map(rdiff).norm(dim=-1)\n",
    "\n",
    "\n",
    "class GeodesicTranslation(torch.nn.Module):\n",
    "    \"\"\"Calculate the Euclidean distance between the translations of two poses.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    @jaxtyped(typechecker=beartype)\n",
    "    def forward(\n",
    "        self,\n",
    "        pose_1: RigidTransform,\n",
    "        pose_2: RigidTransform,\n",
    "    ) -> Float[torch.Tensor, \"b\"]:\n",
    "        t1 = pose_1.get_translation()\n",
    "        t2 = pose_2.get_translation()\n",
    "        return (t1 - t2).norm(dim=1)"
   ]
  },
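  {
   "cell_type": "markdown",
   "id": "e2b9d5c0-log-map-vs-trace-formula",
   "metadata": {},
   "source": [
    "`GeodesicSO3` takes the norm of `so3_log_map(rdiff)` rather than evaluating the arccos formula directly. The following is a short sketch of why the two should agree, assuming `so3_log_map` returns the axis-angle vector $\\theta \\mathbf u$ with $\\| \\mathbf u \\|_2 = 1$ and $\\theta \\in [0, \\pi]$ (an assumption about `diffdrr.utils`; $\\theta$ and $\\mathbf u$ are notation introduced only for this note). For real rotation matrices, $\\mathbf R^* = \\mathbf R^\\top$.\n",
    "\n",
    "A rotation by an angle $\\theta$ has trace $1 + 2 \\cos \\theta$, so for $\\mathbf R_A \\mathbf R_B^\\top$,\n",
    "\\begin{equation}\n",
    "    \\| \\theta \\mathbf u \\|_2 = \\theta = \\arccos \\left( \\frac{\\mathrm{trace}(\\mathbf R_A \\mathbf R_B^\\top) - 1}{2} \\right) \\,.\n",
    "\\end{equation}\n",
    "Because the trace is invariant under cyclic permutation and transposition,\n",
    "\\begin{equation}\n",
    "    \\mathrm{trace}(\\mathbf R_A \\mathbf R_B^\\top) = \\mathrm{trace}(\\mathbf R_B^\\top \\mathbf R_A) = \\mathrm{trace}(\\mathbf R_A^\\top \\mathbf R_B) \\,,\n",
    "\\end{equation}\n",
    "which is the argument of $d_\\theta$ with $r = 1$, i.e., the angular distance in radians."
   ]
  },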
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ad99b83-9759-4909-9930-598cf9c8433a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class GeodesicSE3(torch.nn.Module):\n",
    "    \"\"\"Calculate the distance between transforms in the log-space of SE(3).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    @jaxtyped(typechecker=beartype)\n",
    "    def forward(\n",
    "        self,\n",
    "        pose_1: RigidTransform,\n",
    "        pose_2: RigidTransform,\n",
    "    ) -> Float[torch.Tensor, \"b\"]:\n",
    "        return pose_2.compose(pose_1.inverse()).get_se3_log().norm(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "099c7c47-0d3b-4c4b-8ce9-c50051fa115d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@beartype\n",
    "class DoubleGeodesic(torch.nn.Module):\n",
    "    \"\"\"Calculate the angular and translational geodesics between two SE(3) transformation matrices.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        sdr: float,  # Source-to-detector radius\n",
    "        eps: float = 1e-4,  # Added under the sqrt for numerical stability near zero\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.sdr = sdr\n",
    "        self.eps = eps\n",
    "\n",
    "        self.rotation = GeodesicSO3()\n",
    "        self.translation = GeodesicTranslation()\n",
    "\n",
    "    @jaxtyped(typechecker=beartype)\n",
    "    def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):\n",
    "        angular_geodesic = self.sdr * self.rotation(pose_1, pose_2)\n",
    "        translation_geodesic = self.translation(pose_1, pose_2)\n",
    "        double_geodesic = (\n",
    "            (angular_geodesic).square() + translation_geodesic.square() + self.eps\n",
    "        ).sqrt()\n",
    "        return angular_geodesic, translation_geodesic, double_geodesic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "708e9b0b-3fd1-4671-b9a8-a993ced0d187",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.])\n",
      "tensor([0.1000])\n"
     ]
    }
   ],
   "source": [
    "# SO(3) distance\n",
    "geodesic_so3 = GeodesicSO3()\n",
    "\n",
    "pose_1 = RigidTransform(\n",
    "    torch.tensor([[0.1, 1.0, torch.pi]]),\n",
    "    torch.ones(1, 3),\n",
    "    parameterization=\"euler_angles\",\n",
    "    convention=\"ZYX\",\n",
    ")\n",
    "pose_2 = RigidTransform(\n",
    "    torch.tensor([[0.1, 1.0, torch.pi]]),\n",
    "    torch.ones(1, 3),\n",
    "    parameterization=\"euler_angles\",\n",
    "    convention=\"ZYX\",\n",
    ")\n",
    "\n",
    "print(geodesic_so3(pose_1, pose_2))  # Angular distance in radians\n",
    "\n",
    "pose_1 = RigidTransform(\n",
    "    torch.tensor([[0.1, 1.0, torch.pi]]),\n",
    "    torch.ones(1, 3),\n",
    "    parameterization=\"euler_angles\",\n",
    "    convention=\"ZYX\",\n",
    ")\n",
    "pose_2 = RigidTransform(\n",
    "    torch.tensor([[0.1, 1.1, torch.pi]]),\n",
    "    torch.ones(1, 3),\n",
    "    parameterization=\"euler_angles\",\n",
    "    convention=\"ZYX\",\n",
    ")\n",
    "\n",
    "print(geodesic_so3(pose_1, pose_2))  # Angular distance in radians"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0ff1bce-f8a1-4c40-bb2f-4e9dfa87ac17",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.7355])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# SE(3) distance\n",
    "geodesic_se3 = GeodesicSE3()\n",
    "\n",
    "pose_1 = RigidTransform(\n",
    "    torch.tensor([[0.1, 1.0, torch.pi]]),\n",
    "    torch.ones(1, 3),\n",
    "    parameterization=\"euler_angles\",\n",
    "    convention=\"ZYX\",\n",
    ")\n",
    "pose_2 = RigidTransform(\n",
    "    torch.tensor([[0.1, 1.1, torch.pi]]),\n",
    "    torch.zeros(1, 3),\n",
    "    parameterization=\"euler_angles\",\n",
    "    convention=\"ZYX\",\n",
    ")\n",
    "\n",
    "geodesic_se3(pose_1, pose_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b69987b-98dc-431e-b885-b9e992e0e91a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([51.0000]), tensor([1.7321]), tensor([51.0294]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Angular distance and translational distance both in mm\n",
    "double_geodesic = DoubleGeodesic(1020 / 2)\n",
    "\n",
    "pose_1 = RigidTransform(\n",
    "    torch.tensor([[0.1, 1.0, torch.pi]]),\n",
    "    torch.ones(1, 3),\n",
    "    parameterization=\"euler_angles\",\n",
    "    convention=\"ZYX\",\n",
    ")\n",
    "pose_2 = RigidTransform(\n",
    "    torch.tensor([[0.1, 1.1, torch.pi]]),\n",
    "    torch.zeros(1, 3),\n",
    "    parameterization=\"euler_angles\",\n",
    "    convention=\"ZYX\",\n",
    ")\n",
    "\n",
    "double_geodesic(pose_1, pose_2)"
   ]
  },
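  {
   "cell_type": "markdown",
   "id": "c58a3f71-worked-double-geodesic",
   "metadata": {},
   "source": [
    "The three values returned above can be checked by hand. This is only a sketch that takes the printed outputs at face value: the angular distance of $0.1$ rad comes from the earlier `GeodesicSO3` example with the same pair of rotations, the radius is $r = 1020 / 2 = 510$, and the two translations are $(1, 1, 1)$ and $(0, 0, 0)$.\n",
    "\\begin{equation}\n",
    "    d_\\theta = 510 \\times 0.1 = 51 \\,, \\qquad d_t = \\| (1, 1, 1) - (0, 0, 0) \\|_2 = \\sqrt{3} \\approx 1.7321 \\,,\n",
    "\\end{equation}\n",
    "\\begin{equation}\n",
    "    d = \\sqrt{51^2 + 3 + 10^{-4}} = \\sqrt{2604.0001} \\approx 51.0294 \\,,\n",
    "\\end{equation}\n",
    "where $10^{-4}$ is the default `eps` that `DoubleGeodesic` adds inside the square root. Since $d_\\theta$ dominates here, the double geodesic is almost entirely determined by the rotational error."
   ]
  },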
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba77dd5d-cd83-4a1e-8520-5d27d9b972fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev\n",
    "\n",
    "nbdev.nbdev_export()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
